package com.qf.utils;

import com.alibaba.druid.pool.DruidDataSource;

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.regex.Pattern;

/**
 * Database utility class backed by a Druid connection pool.
 *
 * <p>Connections are bound to the calling thread through a {@link ThreadLocal}. Every helper in
 * this class obtains its connection from {@link #getConnection()}, so statements issued by one
 * thread between {@link #startTransaction()} and {@link #commit()} / {@link #rollback()} share
 * one connection and therefore one transaction. In auto-commit mode each helper hands its
 * connection back (through {@link #close(Connection, Statement, ResultSet)}) when it finishes.
 *
 * <p>Pool settings are read once, when the class is initialised, from
 * {@code DBConfig.properties} on the classpath: {@code driverClassName}, {@code url},
 * {@code username}, {@code password} and {@code maxActive}.
 */
public class DBUtils {

    /** Classpath resource holding the pool configuration. */
    private static final String CONFIG_FILE = "DBConfig.properties";

    /**
     * Characters accepted in the table name given to {@link #getAllCount(String)}. Identifiers
     * cannot be bound as JDBC parameters, so they are whitelisted instead. Without whitespace,
     * quotes, semicolons or comment markers, the value cannot extend the SQL statement.
     */
    private static final Pattern TABLE_NAME = Pattern.compile("[A-Za-z0-9_$.`]+");

    /** Shared connection pool. */
    private static final DruidDataSource pool;

    /** Connection currently bound to each thread; empty when the thread holds none. */
    private static final ThreadLocal<Connection> local = new ThreadLocal<>();

    static {
        Properties properties = loadConfig();
        String driverClassName = properties.getProperty("driverClassName");
        String url = properties.getProperty("url");
        String username = properties.getProperty("username");
        String password = properties.getProperty("password");
        String maxActiveText = properties.getProperty("maxActive");
        if (maxActiveText == null) {
            throw new IllegalStateException("maxActive is missing from " + CONFIG_FILE);
        }
        int maxActive = Integer.parseInt(maxActiveText.trim());

        // Initialise the connection pool and apply the configured parameters.
        pool = new DruidDataSource();
        pool.setDriverClassName(driverClassName);
        pool.setUrl(url);
        pool.setUsername(username);
        pool.setPassword(password);
        pool.setMaxActive(maxActive);
    }

    /**
     * Loads {@link #CONFIG_FILE} from the classpath, closing the stream afterwards.
     *
     * @return the loaded properties
     * @throws IllegalStateException if the resource does not exist
     * @throws RuntimeException wrapping the {@link IOException} if the resource cannot be read
     */
    private static Properties loadConfig() {
        Properties properties = new Properties();
        try (InputStream in = DBUtils.class.getClassLoader().getResourceAsStream(CONFIG_FILE)) {
            if (in == null) {
                throw new IllegalStateException(CONFIG_FILE + " was not found on the classpath");
            }
            properties.load(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return properties;
    }

    /**
     * Returns the connection bound to the current thread. If the thread has none yet, a
     * connection is borrowed from the pool and bound to the thread.
     *
     * @return the current thread's connection
     * @throws SQLException if the pool cannot supply a connection
     */
    public static Connection getConnection() throws SQLException {
        Connection connection = local.get();
        if (connection == null) {
            connection = pool.getConnection();
            local.set(connection);
        }
        return connection;
    }

    /**
     * Closes the given JDBC resources; any argument may be {@code null}.
     *
     * <p>The result set and statement are always closed. The connection is closed and unbound
     * from the thread only while it is in auto-commit mode. Inside a transaction it stays bound
     * so later statements join the same transaction, and {@link #commit()} or
     * {@link #rollback()} releases it. Every resource is attempted even if an earlier one fails.
     *
     * @param connection connection to release, or {@code null}
     * @param statement  statement to close, or {@code null}
     * @param resultSet  result set to close, or {@code null}
     * @throws RuntimeException wrapping the first {@link SQLException}; later ones are attached
     *                          as suppressed exceptions
     */
    public static void close(Connection connection, Statement statement, ResultSet resultSet) {
        SQLException failure = null;
        if (resultSet != null) {
            try {
                resultSet.close();
            } catch (SQLException e) {
                failure = chain(failure, e);
            }
        }
        if (statement != null) {
            try {
                statement.close();
            } catch (SQLException e) {
                failure = chain(failure, e);
            }
        }
        if (connection != null) {
            try {
                if (connection.getAutoCommit()) {
                    release(connection);
                }
            } catch (SQLException e) {
                failure = chain(failure, e);
            }
        }
        if (failure != null) {
            throw new RuntimeException(failure);
        }
    }

    /**
     * Starts a transaction on the current thread's connection by switching off auto-commit.
     *
     * @throws SQLException if no connection can be obtained or auto-commit cannot be changed
     */
    public static void startTransaction() throws SQLException {
        Connection connection = getConnection();
        connection.setAutoCommit(false);
    }

    /**
     * Commits the current thread's transaction, then closes and unbinds its connection.
     * Does nothing if the thread holds no connection.
     *
     * @throws SQLException if the commit or the close fails; the connection is unbound anyway
     */
    public static void commit() throws SQLException {
        finishTransaction(true);
    }

    /**
     * Rolls back the current thread's transaction, then closes and unbinds its connection.
     * Does nothing if the thread holds no connection.
     *
     * @throws SQLException if the rollback or the close fails; the connection is unbound anyway
     */
    public static void rollback() throws SQLException {
        finishTransaction(false);
    }

    /**
     * Commits or rolls back the thread's connection and always releases it afterwards. If both
     * the commit/rollback and the close fail, the first error is thrown with the close failure
     * attached as suppressed.
     */
    private static void finishTransaction(boolean commit) throws SQLException {
        Connection connection = local.get();
        if (connection == null) {
            return;
        }
        SQLException failure = null;
        try {
            if (commit) {
                connection.commit();
            } else {
                connection.rollback();
            }
        } catch (SQLException e) {
            failure = e;
        } finally {
            try {
                release(connection);
            } catch (SQLException e) {
                failure = chain(failure, e);
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Closes the connection and unbinds it from the current thread, even if closing fails. */
    private static void release(Connection connection) throws SQLException {
        try {
            connection.close();
        } finally {
            // remove() rather than set(null): no stale entry is left in pooled threads.
            local.remove();
        }
    }

    /** Returns {@code next} if {@code first} is null; otherwise attaches it to {@code first}. */
    private static SQLException chain(SQLException first, SQLException next) {
        if (first == null) {
            return next;
        }
        first.addSuppressed(next);
        return first;
    }

    /**
     * Executes an update statement (insert, delete or update).
     *
     * @param sql    SQL with {@code ?} placeholders
     * @param params values bound to the placeholders in order
     * @return the number of affected rows
     * @throws SQLException on any database error
     */
    public static int commonUpdate(String sql, Object... params) throws SQLException {
        Connection connection = null;
        PreparedStatement statement = null;
        try {
            connection = getConnection();
            statement = connection.prepareStatement(sql);
            paramHandler(statement, params);
            return statement.executeUpdate();
        } finally {
            close(connection, statement, null);
        }
    }

    /**
     * Executes an insert and returns the generated key. This only makes sense when the
     * primary key is an {@code int}.
     *
     * @param sql    insert SQL with {@code ?} placeholders
     * @param params values bound to the placeholders in order
     * @return the first generated key read as an {@code int}, or {@code 0} if none was returned
     * @throws SQLException on any database error
     */
    public static int commonInsert(String sql, Object... params) throws SQLException {
        Connection connection = null;
        PreparedStatement statement = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            statement = connection.prepareStatement(sql, PreparedStatement.RETURN_GENERATED_KEYS);
            paramHandler(statement, params);
            statement.executeUpdate();

            resultSet = statement.getGeneratedKeys();
            return resultSet.next() ? resultSet.getInt(1) : 0;
        } finally {
            close(connection, statement, resultSet);
        }
    }

    /**
     * Runs a query and maps every row onto a new instance of {@code clazz}.
     *
     * @param clazz  target type; needs an accessible no-arg constructor
     * @param sql    query with {@code ?} placeholders
     * @param params values bound to the placeholders in order
     * @param <T>    row type
     * @return one object per row, in result order (empty if no rows match)
     * @throws SQLException on any database error
     */
    public static <T> List<T> commonQueryList(Class<T> clazz, String sql, Object... params) throws SQLException {
        Connection connection = null;
        PreparedStatement statement = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            statement = connection.prepareStatement(sql);
            paramHandler(statement, params);
            resultSet = statement.executeQuery();
            ResultSetMetaData metaData = resultSet.getMetaData();
            List<T> list = new ArrayList<>();
            while (resultSet.next()) {
                list.add(mapRow(clazz, resultSet, metaData));
            }
            return list;
        } finally {
            close(connection, statement, resultSet);
        }
    }

    /**
     * Runs a query and maps its first row onto a new instance of {@code clazz}.
     *
     * @param clazz  target type; needs an accessible no-arg constructor
     * @param sql    query with {@code ?} placeholders
     * @param params values bound to the placeholders in order
     * @param <T>    row type
     * @return the mapped first row, or {@code null} if the query returned no rows
     * @throws SQLException on any database error
     */
    public static <T> T commonQueryObj(Class<T> clazz, String sql, Object... params) throws SQLException {
        Connection connection = null;
        PreparedStatement statement = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            statement = connection.prepareStatement(sql);
            paramHandler(statement, params);
            resultSet = statement.executeQuery();
            if (resultSet.next()) {
                return mapRow(clazz, resultSet, resultSet.getMetaData());
            }
            return null;
        } finally {
            close(connection, statement, resultSet);
        }
    }

    /**
     * Maps the result set's current row onto a new {@code clazz} instance.
     *
     * <p>Each column is matched to a field by its label, so SQL aliases such as
     * {@code user_name AS userName} are honoured. Values are read by column index, so
     * duplicate column names (for example in joins) map correctly. Columns without a matching
     * field are ignored.
     */
    private static <T> T mapRow(Class<T> clazz, ResultSet resultSet, ResultSetMetaData metaData) throws SQLException {
        T row;
        try {
            row = clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
        int count = metaData.getColumnCount();
        for (int i = 1; i <= count; i++) {
            setField(row, metaData.getColumnLabel(i), resultSet.getObject(i));
        }
        return row;
    }

    /**
     * Returns the total number of rows in a table.
     *
     * @param table table name (optionally schema-qualified or back-quoted); restricted to
     *              ASCII letters, digits, {@code _}, {@code $}, {@code .} and back-quotes
     *              because it is spliced into the SQL text
     * @return the row count
     * @throws IllegalArgumentException if {@code table} is null or contains other characters
     * @throws SQLException             on any database error
     */
    public static int getAllCount(String table) throws SQLException {
        if (table == null || !TABLE_NAME.matcher(table).matches()) {
            throw new IllegalArgumentException("Invalid table name: " + table);
        }
        Connection connection = null;
        PreparedStatement statement = null;
        ResultSet resultSet = null;
        try {
            connection = getConnection();
            statement = connection.prepareStatement("select count(1) from " + table);
            resultSet = statement.executeQuery();
            return resultSet.next() ? resultSet.getInt(1) : 0;
        } finally {
            close(connection, statement, resultSet);
        }
    }

    /**
     * Binds {@code params} to the statement's placeholders in order (1-based JDBC indexes).
     * A {@code null} array is treated as "no parameters".
     */
    private static void paramHandler(PreparedStatement statement, Object... params) throws SQLException {
        if (params == null) {
            return;
        }
        for (int i = 0; i < params.length; i++) {
            statement.setObject(i + 1, params[i]);
        }
    }

    /**
     * Looks up a declared field on the class or any of its superclasses.
     *
     * @param clazz class to start from
     * @param name  field name
     * @return the first matching field walking up the hierarchy, or {@code null} if none exists
     */
    private static Field getField(Class<?> clazz, String name) {
        for (Class<?> c = clazz; c != null; c = c.getSuperclass()) {
            try {
                return c.getDeclaredField(name);
            } catch (NoSuchFieldException | SecurityException ignored) {
                // Not declared here (or not inspectable): keep searching the superclass.
            }
        }
        return null;
    }

    /**
     * Best-effort assignment of a field on {@code obj}.
     *
     * <p>Unknown field names are ignored. A {@code null} value is skipped for primitive fields,
     * which keep their default value. Any other assignment failure (for example a type mismatch)
     * is reported via a stack trace and does not abort the mapping.
     *
     * @param obj   target object
     * @param name  field name
     * @param value value to assign
     */
    private static void setField(Object obj, String name, Object value) {
        Field field = getField(obj.getClass(), name);
        if (field == null) {
            return;
        }
        if (value == null && field.getType().isPrimitive()) {
            // SQL NULL cannot be stored in a primitive field; leave its default value.
            return;
        }
        field.setAccessible(true);
        try {
            field.set(obj, value);
        } catch (IllegalArgumentException | IllegalAccessException e) {
            e.printStackTrace();
        }
    }


}
